#!/usr/bin/env python

from PIL import Image

import cv2
import numpy as np
import matplotlib.pylab  as plt
from skimage import io
from skimage import morphology,draw
from skimage import measure,color
import time
from sklearn.cluster import MeanShift
from posProcess import posGeter

try:
    import numpy as np    
except ImportError:
    raise RuntimeError('cannot import numpy, make sure numpy package is installed')

try:
    from auto.control import controller
except ImportError:
    from control import controller


class autoAgent():
    """Image-processing agent that extracts line features from camera frames.

    Frames are perspective-warped to a top-down view, down-sampled,
    binarised, skeletonised and Hough-transformed; the intersections of the
    detected lines are clustered and handed to ``posGeter`` to compute a pose.
    """

    def __init__(self, processer, controller):
        """Create the agent.

        Args:
            processer: object whose ``exit()`` is called from ``self.exit``
                when ``self.world`` is set.
            controller: stored as ``self.controllProcesser``.
        """
        print('auto Agent start')
        self.controller = None
        self.error = 0
        self.world = None
        self.processer = processer
        self.ex = False
        self.showPic = False
        self.logging = True
        self.savePic = False
        self.showTimeBool = False
        self.picProcess = False
        self.controllProcesser = controller
        #self.controllProcesser.setRun(0)
        self.cv2hough = True
        self.picShape = [0, 0]
        self.posGetter = posGeter()
        # Initialised here so save()/showTime() work before the first getPic().
        self.otherSource = True
        self.startTime = time.time()
        self.startTime1 = self.startTime

    def outside(self, x, y):
        """Return True if (x, y) is not a valid pixel of an image of ``self.picShape``.

        ``self.picShape`` is (rows, cols); x is compared to cols, y to rows.
        Valid indices are 0..cols-1 and 0..rows-1, so the upper bound is exclusive.
        """
        if x < 0 or y < 0:
            return True
        if x >= self.picShape[1] or y >= self.picShape[0]:
            return True
        return False

    def pyrDown(self, image, levels=1):
        """Apply ``cv2.pyrDown`` ``levels`` times (each call halves both sides)."""
        for _ in range(levels):
            image = cv2.pyrDown(image)
        return image

    def toushiAndPyr(self, pic, gray=False):
        """Perspective-warp ``pic``, down-sample it once, optionally convert to gray."""
        print('autoAgent get toushiAndPyr TU')
        image = Image.fromarray(pic)
        image = self.toushiTrans(np.asarray(image), show=False, save=False)
        image = self.pyrDown(image)
        if gray:
            image = self.color2gray(image, show=False, save=False)
        return image

    def erzhiTu(self, pic, bianxian=False, threshold=230):
        """Return a binary (0/255 uint8) image of ``pic`` after warp, gray and down-sampling.

        Args:
            pic: input frame array.
            bianxian: if True keep only the edge lines (pixels > ``threshold``);
                otherwise keep the edge lines AND the exactly-black background
                (pixels == 0, e.g. the area outside the warped frame).
            threshold: gray level above which a pixel counts as an edge line.
        """
        print('autoAgent get erzhi TU')
        image = Image.fromarray(pic)
        image = self.toushiTrans(np.asarray(image), show=False, save=False)
        image = self.color2gray(image, show=False, save=False)
        image = self.pyrDown(image)
        if bianxian:
            image = (image > threshold) * 1
        else:
            image = (image > threshold) * 1 + (image == 0) * 1

        erzhiImg = np.uint8(image * 255)
        return erzhiImg

    def dilate(self, image, kernelBand=15):
        """Morphological closing (dilate then erode) with a square kernel.

        Despite the name this is a closing: it fills gaps narrower than the
        kernel while roughly preserving the outer shape.
        """
        kernel = np.ones((kernelBand, kernelBand), np.uint8)
        image = cv2.dilate(image, kernel)
        image = cv2.erode(image, kernel)
        return image

    def cheweiMean(self, pic):
        """Locate the smallest enclosed dark region and return its normalised offset.

        Returns:
            [x_bias, y_bias]. x_bias is the horizontal offset of the centroid from
            the image centre, scaled by the visible half-width on that side.
            y_bias is the vertical offset scaled by half the image height; it is
            positive above the centre (image y grows downwards).

        Raises:
            RuntimeError: if no region other than the background is found.
        """
        image = self.erzhiTu(pic)
        oriImage = self.toushiAndPyr(pic)
        image = self.dilate(image)

        # Label the non-white regions; white (255) pixels are the background.
        labels = measure.label(image, connectivity=2, background=255)
        regionCount = int(labels.max())
        print('regions number:', regionCount + 1)  # labels start at 0 (background)
        if regionCount < 1:
            raise RuntimeError('error')

        # Smallest labelled region; argmin keeps the lowest label on ties.
        areas = np.bincount(labels.ravel(), minlength=regionCount + 1)[1:]
        for label, labelArea in enumerate(areas, start=1):
            print(label, labelArea, regionCount)
        targetLable = int(np.argmin(areas)) + 1

        cheweizhongjianImage = (labels == targetLable) * 255
        cheweibianxianImage = self.erzhiTu(pic, True)
        finalImage = cheweibianxianImage + cheweizhongjianImage

        [x, y] = self.caculateMean(finalImage)
        cv2.circle(oriImage, (int(x), int(y)), 16, (0, 255, 255), 2)

        image = self.toushiAndPyr(pic, True)

        # Visible width on the centroid's row: first and last non-zero pixel.
        hangImg = image[int(y)]
        width = hangImg.shape[0]
        nonzero = np.flatnonzero(hangImg)
        if nonzero.size:
            zuoPix, youPix = int(nonzero[0]), int(nonzero[-1])
        else:
            zuoPix, youPix = width, 0
        print(zuoPix, youPix)

        # NOTE(review): a visible edge exactly at the centre gives a zero
        # denominator (inf/nan bias); behaviour left unchanged.
        half = width / 2
        if x > half:
            x_bias = (x - half) / (youPix - half)
        else:
            x_bias = - (x - half) / (zuoPix - half)

        # Image y axis points down, hence the minus sign.
        halfHeight = finalImage.shape[0] / 2
        y_bias = - (y - halfHeight) / halfHeight

        print('cheweiMean:', x_bias, y_bias)
        return [x_bias, y_bias]

    def whiteMean(self, pic):
        """Draw the centroid of the binarised frame and display it."""
        erzhiImg = self.erzhiTu(pic)
        oriImage = self.toushiAndPyr(pic)
        print("zhongshu/:%s" % np.sum(erzhiImg))
        print('输入图像shape：%s' % str(erzhiImg.shape))

        x_, y_ = self.caculateMean(erzhiImg)

        cv2.circle(oriImage, (int(x_), int(y_)), 16, (0, 255, 255), 2)
        self.show([oriImage, self.dilate(erzhiImg)], True)

    def caculateMean(self, image):
        """Return the intensity-weighted centroid [x, y] of a 2-D image.

        x is the weighted mean column index and y the weighted mean row index,
        both 0-based so they can be used directly as pixel coordinates.

        Raises:
            ValueError: if the image sums to zero (centroid undefined).
        """
        image = np.asarray(image)
        rows, cols = image.shape
        zongshu = np.sum(image, dtype=np.float64)
        if zongshu == 0:
            raise ValueError('caculateMean: image has no non-zero pixels')

        x_ = np.sum(image * np.arange(cols).reshape(1, cols), dtype=np.float64) / zongshu
        y_ = np.sum(image * np.arange(rows).reshape(rows, 1), dtype=np.float64) / zongshu

        print('结果值：（%s，%s）' % (x_, y_))
        return [x_, y_]

    def getPic(self, pic, otherSource=True):
        """Run the full line-detection pipeline on one frame.

        Args:
            pic: input frame array.
            otherSource: if False the frame is rotated by 270 degrees first.
                If True, save() is disabled.

        Returns:
            ``[x, y + 14, angle]`` from ``self.posGetter.getPostion`` when exactly
            four line-intersection clusters are found, otherwise None.
        """
        self.otherSource = otherSource

        print("autpAgent::getPic count:")
        print(type(pic))
        image = Image.fromarray(pic)

        # Test images do not need rotating.
        if not otherSource:
            image = image.transpose(Image.ROTATE_270)

        # Perspective transform.
        self.startTime = time.time()
        self.startTime1 = time.time()
        image = self.toushiTrans(np.asarray(image), show=False, save=False)
        self.showTime('toushi')
        self.save(image, 'toushi')
        self.show([pic, image])
        oriImage = cv2.pyrDown(image)

        # Grayscale conversion.
        self.time = time.time()
        image = self.color2gray(image, show=False, save=False)
        self.showTime('huidu')
        self.save(image, 'huidu')

        # Down-sample once.
        image = cv2.pyrDown(image)
        self.showTime('pool')
        self.save(image, 'pool')
        self.picShape = image.shape

        # Binarise: bright pixels (> 230) are kept as line markings.
        image = (image > 230) * 1
        erzhiImg = np.uint8(image * 255)
        self.show([erzhiImg])
        self.showTime('Kmeans')
        # Save the uint8 mask; PIL cannot convert the int64 0/1 array.
        self.save(erzhiImg, 'Kmeans')

        # Skeletonise the mask.
        image = self.skeletonGene(image > 0)
        self.showTime('skeletonGene')
        self.show([image])
        self.save(image, 'skeletonGene')

        # Hough line detection.
        lines = cv2.HoughLines(np.uint8(image), 1, np.pi / 90, 50)
        if lines is None:
            return None
        print('lines type:%s' % type(lines))
        print('lines.shape:[%s]' % str(lines.shape))
        # cv2 returns (N, 1, 2) rows of (rho, theta); reorder to a (2, N) [theta; rho] array.
        rho = lines[:, 0, 0].copy()
        theta = lines[:, 0, 1].copy()
        lines = np.vstack([theta, rho])

        # Line intersections.
        [xs, ys] = self.getIntersection(lines, oriImage)
        self.showTime('houghImg')

        # Merge nearby intersections into cluster centres.
        ms = MeanShift(bandwidth=100, bin_seeding=True)
        if len(xs) != 0:
            ms.fit(np.array([xs, ys]).T)
            cluster_centers = ms.cluster_centers_
            xs = cluster_centers.T[0]
            ys = cluster_centers.T[1]

        print('len(xs) is %d' % len(xs))
        for px, py in zip(xs, ys):
            print('the postion is x:[%s] y:[%s]' % (px, py))

        self.startTime = self.startTime1
        self.showTime('all')

        if len(xs) != 4:
            return None

        [x, y, angle] = self.posGetter.getPostion(xs, ys, oriImage.shape[1])
        print([x, y, angle])
        # NOTE(review): +14 is an undocumented fixed y offset, presumably a calibration constant.
        return [x, y + 14, angle]

    def save(self, image, name=None, save=False):
        """Write ``image`` to ``out/<name>.png`` when saving is enabled.

        Nothing is written when ``self.otherSource`` is set, or when neither
        ``self.savePic`` nor ``save`` is true. ``name`` defaults to the
        current timestamp.
        """
        import os

        if (not self.savePic and not save) or self.otherSource:
            return

        self.log('save image name:%s  imgType:%s' % (str(name), type(image)))
        if name is None:
            name = time.time()

        if isinstance(image, np.ndarray):
            self.log('the image shape is :%s' % str(image.shape))
            image = Image.fromarray(image)

        if image.mode != 'RGB':
            image = image.convert('RGB')

        os.makedirs('out', exist_ok=True)
        image.save('out/%s.png' % name)

    def log(self, log):
        """Print ``log`` when logging is enabled."""
        if self.logging:
            print(log)

    def showTime(self, runFunction="unKnow"):
        """Print the time elapsed since ``self.startTime`` and restart the timer."""
        if self.showTimeBool:
            runTime = time.time() - self.startTime
            print('%s:runtime:%s' % (runFunction, runTime))
            self.startTime = time.time()

    def lines_detector_hough(self, edge, ThetaDim=None, DistStep=None, threshold=None, halfThetaWindowSize=2, halfDistWindowSize=None):
        '''Hand-written Hough line transform with non-maximum suppression.

        :param edge: binary edge image (non-zero pixels vote)
        :param ThetaDim: number of bins over theta in [0, pi); larger is finer
        :param DistStep: size of one rho bin in pixels
        :param threshold: minimum vote count for a line (default 20% of the maximum)
        :return: (2, K) array of detected lines, row 0 = theta, row 1 = rho.
                 rho = i*cos(theta) + j*sin(theta) for pixel (row i, col j)
                 and may be negative.
        @author: bilibili-会飞的吴克
        '''
        print('start hough Trans  the shape is %s' % str(edge.shape))
        imgsize = edge.shape
        if ThetaDim is None:
            ThetaDim = 720
        if DistStep is None:
            DistStep = 2
        MaxDist = np.sqrt(imgsize[0] ** 2 + imgsize[1] ** 2)
        DistDim = int(np.ceil(MaxDist / DistStep))

        if halfDistWindowSize is None:
            halfDistWindowSize = DistDim / 50
        # rho lies in (-MaxDist, MaxDist): column b holds rho bin (b - DistDim),
        # so negative rho no longer wraps around into the positive bins.
        accumulator = np.zeros((ThetaDim, 2 * DistDim + 1))

        thetas = np.arange(ThetaDim) * np.pi / ThetaDim
        rows, cols = np.nonzero(edge)
        for k in range(ThetaDim):
            rho = rows * np.cos(thetas[k]) + cols * np.sin(thetas[k])
            bins = np.round(rho * DistDim / MaxDist).astype(np.intp) + DistDim
            accumulator[k] += np.bincount(bins, minlength=accumulator.shape[1])

        M = accumulator.max()

        if threshold is None:
            threshold = int(M * 0.2)
        result = np.array(np.where(accumulator > threshold))  # thresholding

        # Non-maximum suppression: keep bins not exceeded in their window.
        temp = [[], []]
        noMaxScope = 2
        for i in range(result.shape[1]):
            maxNum1 = int(max(0, result[0, i] - halfThetaWindowSize * noMaxScope + 1))
            maxNum2 = int(max(0, result[1, i] - halfDistWindowSize * noMaxScope + 1))
            minNum1 = int(min(result[0, i] + halfThetaWindowSize * noMaxScope, accumulator.shape[0]))
            minNum2 = int(min(result[1, i] + halfDistWindowSize * noMaxScope, accumulator.shape[1]))

            eight_neiborhood = accumulator[maxNum1:minNum1, maxNum2:minNum2]
            if (accumulator[result[0, i], result[1, i]] >= eight_neiborhood).all():
                temp[0].append(result[0, i])
                temp[1].append(result[1, i])

        result = np.array(temp, dtype=np.float64).reshape(2, -1)
        result[0] = result[0] * np.pi / ThetaDim
        result[1] = (result[1] - DistDim) * MaxDist / DistDim
        return result

    def getIntersection(self, lines, img='', drawDebug=False):
        """Intersect every pair of roughly perpendicular lines.

        Args:
            lines: (2, N) array, row 0 = theta, row 1 = rho, each line being
                x*cos(theta) + y*sin(theta) = rho.
            img: background image, only used when ``drawDebug`` is True.
            drawDebug: draw each accepted pair and its intersection.

        Returns:
            [xs, ys] integer-rounded intersections inside ``self.picShape``
            (swapped to [ys, xs] when ``self.cv2hough`` is False).
            Pairs whose angle difference is outside [1, 2] rad are skipped.
        """
        theta = lines[0]
        r = lines[1]
        Cos = np.cos(theta)
        Sin = np.sin(theta)

        xs = []
        ys = []
        lineCount = lines.shape[1]
        self.log('getIntersection::there is %s line to compute' % lineCount)
        for i in range(lineCount):
            for j in range(i + 1, lineCount):
                # Determinant of the 2x2 system; zero means parallel lines.
                base = Cos[i] * Sin[j] - Cos[j] * Sin[i]
                if base == 0:
                    continue

                angleDiff = np.abs(theta[i] - theta[j])
                if angleDiff < 1 or angleDiff > 2:
                    continue

                # Cramer's rule.
                x = (r[i] * Sin[j] - r[j] * Sin[i]) / base
                y = (r[j] * Cos[i] - r[i] * Cos[j]) / base

                if self.outside(x, y):
                    continue

                xs.append(round(x))
                ys.append(round(y))

                if drawDebug:
                    lines0 = np.array([
                        [theta[i], theta[j]],
                        [r[i], r[j]]
                    ])
                    point = np.array([
                        [x], [y]
                    ])
                    self.drawLineAndPoint(img, lines0, point)

        xs = np.array(xs)
        ys = np.array(ys)
        # The non-cv2 Hough uses (row, col) order, so swap for that convention.
        if not self.cv2hough:
            return [ys, xs]
        return [xs, ys]

    def drawLineAndPoint(self, img, lines, point):
        """Debug helper: draw the given lines and points on a copy of ``img`` and show it."""
        oriImage = np.uint8(img).copy()
        [xs, ys] = point
        for i in range(lines.shape[1]):
            theta = lines[0][i]
            rho = lines[1][i]
            a = np.cos(theta)
            b = np.sin(theta)
            x0 = a * rho
            y0 = b * rho

            # Two far-apart points on the line so it spans the whole image.
            x1 = int(x0 + 10000 * (-b))
            y1 = int(y0 + 10000 * (a))
            x2 = int(x0 - 10000 * (-b))
            y2 = int(y0 - 10000 * (a))
            print('cv2::drawLine theta:[%s] rho:[%s] 0point:[%d,%d] 1point:[%d,%d] 2point:[%d,%d]' % (theta, rho, x0, y0, x1, y1, x2, y2))
            if self.cv2hough:
                cv2.line(oriImage, (x1, y1), (x2, y2), (255, 255, 255), 2)
            else:
                cv2.line(oriImage, (y1, x1), (y2, x2), (255, 255, 255), 2)

        for i in range(len(xs)):
            cv2.circle(oriImage, (int(xs[i]), int(ys[i])), 8, (255, 255, 255), 2)

        self.show([oriImage], True)

    def drawLines(self, lines, edge, color=(255, 255, 255), err=2, showIntersection=False):
        """Paint pixels within ``err`` of any (theta, rho) line.

        A 2-D ``edge`` is stacked into 3 channels. A 3-channel ``edge`` is
        modified in place.
        """
        if len(edge.shape) == 2:
            result = np.dstack((edge, edge, edge))
        else:
            result = edge
        Cos = np.cos(lines[0])
        Sin = np.sin(lines[0])
        [xs, ys] = self.getIntersection(lines)

        for i in range(edge.shape[0]):
            for j in range(edge.shape[1]):
                e = np.abs(lines[1] - i * Cos - j * Sin)
                if (e < err).any():
                    result[i, j] = color

                if showIntersection:
                    e = np.square(xs - i) + np.square(ys - j)
                    if (e < err * err * 4).any():
                        result[i, j] = (255, 0, 0)

        return result

    def exit(self):
        """Shut down: call ``processer.exit()`` when a world is set, then exit the process.

        Does nothing beyond printing when ``self.ex`` is True.
        """
        import sys

        print('autoAgent::exit')
        if self.ex:
            return

        if self.world is not None:
            self.processer.exit()
        sys.exit(0)

    def color2gray(self, img, show=False, save=False):
        """Convert a 3-channel image to grayscale with ``cv2.COLOR_BGR2GRAY``.

        NOTE(review): frames come through PIL and are probably RGB; BGR2GRAY
        only changes the channel weights slightly — confirm intended.
        """
        gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

        if show:
            print('gray_img::show')
            plt.subplot(121), plt.imshow(img), plt.title('Input')
            plt.subplot(122), plt.imshow(gray_img), plt.title('Output')
            plt.show()
        if save:
            print('gray_img::save')
            io.imsave('gray_img.png', gray_img)
        return gray_img

    def toushiTrans(self, img, show=False, save=False, degree=20, dsize=(4000, 4000)):
        """Perspective-warp ``img`` using a hard-coded four-point calibration.

        Args:
            img: input image array.
            show: plot input and output with matplotlib.
            save: write the result to ``out.png``.
            degree: which calibration to use (presumably the camera tilt);
                only 20 and 90 are defined.
            dsize: (width, height) of the output image.

        Raises:
            ValueError: if ``degree`` has no calibration.
        """
        print(('toushiTrans show:%s save:%s') % (show, save))
        if degree == 90:
            pts1 = np.float32([[275, 278], [338, 281], [279, 232], [339, 233]])
            pts2 = np.float32([[600, 700], [700, 700], [600, 600], [700, 600]])
            pts2 += 300
        elif degree == 20:
            pts1 = np.float32([[96, 389], [488, 398], [167, 153], [422, 159]])
            pts2 = np.float32([[557, 2840], [1492, 2845], [508, 1500], [1443, 1505]])
            pts2 += np.float32([1000, 800])
        else:
            raise ValueError('toushiTrans: unsupported degree %s (expected 20 or 90)' % degree)

        M = cv2.getPerspectiveTransform(pts1, pts2)
        dst = cv2.warpPerspective(img, M, dsize)

        if show:
            print('toushiTrans::show')
            plt.subplot(121), plt.imshow(img), plt.title('Input')
            plt.subplot(122), plt.imshow(dst), plt.title('Output')
            plt.show()
        if save:
            print('toushiTrans::save')
            io.imsave('out.png', dst)

        return dst

    def Kmeans(self, img, k):
        """1-D k-means on pixel intensities; return a 0/255 mask of the brightest cluster.

        Centres start evenly spaced in [0, 255). Iteration stops after 100
        rounds or when the total distance stops changing. A cluster that
        loses all its pixels keeps its previous centre.
        """
        print('start Kmeans k = ', k)
        centers = (np.arange(0, k, 1) * 255 / k).reshape(1, k)
        totalNum = img.shape[0] * img.shape[1]
        vector = img.reshape([totalNum, 1]).copy()
        master = np.zeros(totalNum, dtype=np.intp)
        iters = 0
        loss = 1
        cost = 99999999999999999
        while iters < 100 and loss != 0:
            iters += 1
            dist = np.abs(centers - vector)
            master = np.argmin(dist, axis=1)
            lastCost = cost
            cost = np.sum(dist)
            loss = lastCost - cost

            for i in range(k):
                members = vector[master == i]
                if members.size:
                    centers[0, i] = np.mean(members)

        needIndex = int(np.argmax(centers, axis=1)[0])
        imgs = []
        for i in range(k):
            mask = np.zeros([totalNum, 1])
            mask[master == i] = 255
            imgs.append(mask.reshape(img.shape))

        # Move the brightest cluster's mask to the front.
        imgs[0], imgs[needIndex] = imgs[needIndex], imgs[0]

        self.show(imgs)
        return imgs[0]

    def show(self, images, show=False):
        """Plot 1, 2 or 4 images with matplotlib when ``show`` or ``self.showPic`` is set."""
        if not (show or self.showPic):
            return
        print('show picture num = %s' % len(images))
        if len(images) == 2:
            plt.subplot(121), plt.imshow(images[0]), plt.title('Input')
            plt.subplot(122), plt.imshow(images[1]), plt.title('Output')
        elif len(images) == 4:
            for idx, img in enumerate(images):
                plt.subplot(221 + idx), plt.imshow(img), plt.title(str(idx))
        elif len(images) == 1:
            plt.imshow(images[0])
        else:
            print('can`t show  number is wrong')
            return
        plt.show()

    def pooling(self, inputMap, poolSize=3, poolStride=2, mode='max'):
        """Pool a 2-D array with square windows.

        INPUTS:
                inputMap - 2-D input array
                poolSize - side length of each square window
                poolStride - step between successive windows
                mode - 'max' (default) or 'avg'/'mean'

        OUTPUTS:
                outputMap - array of shape (ceil(rows/stride), ceil(cols/stride))

        The input is edge-padded so the last window is complete.
        """
        if mode == 'max':
            reduce = np.max
        elif mode in ('avg', 'mean'):
            reduce = np.mean
        else:
            raise ValueError("pooling: unsupported mode %r" % (mode,))

        in_row, in_col = np.shape(inputMap)
        out_row = (in_row + poolStride - 1) // poolStride
        out_col = (in_col + poolStride - 1) // poolStride
        outputMap = np.zeros((out_row, out_col))

        # Pad just enough for the last window to fit.
        pad_row = max(0, (out_row - 1) * poolStride + poolSize - in_row)
        pad_col = max(0, (out_col - 1) * poolStride + poolSize - in_col)
        temp_map = np.pad(inputMap, ((0, pad_row), (0, pad_col)), 'edge')

        for r_idx in range(out_row):
            startY = r_idx * poolStride
            for c_idx in range(out_col):
                startX = c_idx * poolStride
                poolField = temp_map[startY:startY + poolSize, startX:startX + poolSize]
                outputMap[r_idx, c_idx] = reduce(poolField)

        return outputMap

    def skeletonGene(self, img):
        """Return the morphological skeleton of a boolean image."""
        skeleton = morphology.skeletonize(img)
        self.show([img, skeleton])
        return skeleton





        


    